{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CIFAR-10 Recipe\n",
    "In this notebook, we will show how to train a state-of-the-art CIFAR-10 network with MXNet and extract features from the network.\n",
    "This example will cover:\n",
    "\n",
    "- Network/Data definition \n",
    "- Multi GPU training\n",
    "- Model saving and loading\n",
    "- Prediction/Extracting Feature\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import mxnet as mx\n",
    "import logging\n",
    "import numpy as np\n",
    "\n",
    "# setup logging\n",
    "logger = logging.getLogger()\n",
    "logger.setLevel(logging.DEBUG)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, let's define some helper functions to build a simplified Inception network. More details about how to compose symbols into components can be found in [composite_symbol](composite_symbol.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Basic Conv + BN + ReLU factory\n",
    "def ConvFactory(data, num_filter, kernel, stride=(1,1), pad=(0, 0), act_type=\"relu\", workspace=256):\n",
    "    \"\"\"Build a Convolution -> BatchNorm -> Activation chain.\n",
    "\n",
    "    data: input symbol\n",
    "    num_filter: number of convolution filters\n",
    "    kernel, stride, pad: convolution kernel size, stride and padding\n",
    "    act_type: activation type passed to mx.symbol.Activation (default \"relu\")\n",
    "    workspace: convolution workspace size in MB; defaults to 256, the value\n",
    "        previously hard-coded here\n",
    "    Returns the activation symbol.\n",
    "    \"\"\"\n",
    "    # ``workspace`` may influence convolution performance. You may set a larger\n",
    "    # value; the convolution layer only uses what it actually needs, not the full amount.\n",
    "    # MXNet handles reuse of the workspace without parallelism conflicts.\n",
    "    conv = mx.symbol.Convolution(data=data, workspace=workspace,\n",
    "                                 num_filter=num_filter, kernel=kernel, stride=stride, pad=pad)\n",
    "    bn = mx.symbol.BatchNorm(data=conv)\n",
    "    act = mx.symbol.Activation(data=bn, act_type=act_type)\n",
    "    return act"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# A simple downsampling factory\n",
    "def DownsampleFactory(data, ch_3x3):\n",
    "    \"\"\"Downsample ``data`` with stride 2 via two parallel branches.\n",
    "\n",
    "    One branch is a 3x3/stride-2 ConvFactory with ``ch_3x3`` filters, the other\n",
    "    a 3x3/stride-2 max pooling of the input; both use pad (1, 1) and their\n",
    "    outputs are concatenated (conv branch first).\n",
    "    \"\"\"\n",
    "    branches = [\n",
    "        ConvFactory(data=data, kernel=(3, 3), stride=(2, 2), num_filter=ch_3x3, pad=(1, 1)),\n",
    "        mx.symbol.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pad=(1,1), pool_type='max'),\n",
    "    ]\n",
    "    return mx.symbol.Concat(*branches)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# A simple module\n",
    "def SimpleFactory(data, ch_1x1, ch_3x3):\n",
    "    \"\"\"Concatenate a 1x1 and a 3x3 ConvFactory branch over the same input.\n",
    "\n",
    "    Both branches use stride 1 and the 3x3 branch is padded by 1, so the two\n",
    "    outputs keep the input's spatial size and can be concatenated.\n",
    "    \"\"\"\n",
    "    branch_1x1 = ConvFactory(data=data, kernel=(1, 1), pad=(0, 0), num_filter=ch_1x1)\n",
    "    branch_3x3 = ConvFactory(data=data, kernel=(3, 3), pad=(1, 1), num_filter=ch_3x3)\n",
    "    return mx.symbol.Concat(branch_1x1, branch_3x3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can build a network with these component factories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Input symbol; the data iterators declared below feed it 3x28x28 crops\n",
    "data = mx.symbol.Variable(name=\"data\")\n",
    "# Stem convolution\n",
    "conv1 = ConvFactory(data=data, kernel=(3,3), pad=(1,1), num_filter=96, act_type=\"relu\")\n",
    "# Stage 3: two simple modules, then a stride-2 downsample\n",
    "in3a = SimpleFactory(conv1, 32, 32)\n",
    "in3b = SimpleFactory(in3a, 32, 48)\n",
    "in3c = DownsampleFactory(in3b, 80)\n",
    "# Stage 4: four simple modules, then a stride-2 downsample\n",
    "in4a = SimpleFactory(in3c, 112, 48)\n",
    "in4b = SimpleFactory(in4a, 96, 64)\n",
    "in4c = SimpleFactory(in4b, 80, 80)\n",
    "in4d = SimpleFactory(in4c, 48, 96)\n",
    "in4e = DownsampleFactory(in4d, 96)\n",
    "# Stage 5: two simple modules\n",
    "in5a = SimpleFactory(in4e, 176, 160)\n",
    "in5b = SimpleFactory(in5a, 176, 160)\n",
    "# Global average pooling: 28x28 inputs are downsampled twice (28 -> 14 -> 7),\n",
    "# so the 7x7 kernel covers the whole map. Update the kernel if data_shape changes.\n",
    "pool = mx.symbol.Pooling(data=in5b, pool_type=\"avg\", kernel=(7,7), name=\"global_avg\")\n",
    "flatten = mx.symbol.Flatten(data=pool)\n",
    "# 10-way classifier, one output per CIFAR-10 class\n",
    "fc = mx.symbol.FullyConnected(data=flatten, num_hidden=10)\n",
    "softmax = mx.symbol.SoftmaxOutput(name='softmax',data=fc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# If you'd like to see the network structure, run the plot_network function\n",
    "#mx.viz.plot_network(symbol=softmax,node_attrs={'shape':'oval','fixedsize':'false'}) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Wrap the softmax symbol in a FeedForward model.\n",
    "# For demo purposes the model is trained for 1 epoch only, on the first GPU.\n",
    "num_epoch = 1\n",
    "model = mx.model.FeedForward(\n",
    "    ctx=mx.gpu(),\n",
    "    symbol=softmax,\n",
    "    num_epoch=num_epoch,\n",
    "    learning_rate=0.05,\n",
    "    momentum=0.9,\n",
    "    wd=0.00001)\n",
    "\n",
    "# A learning rate scheduler can be attached to the model as well:\n",
    "# model = mx.model.FeedForward(ctx=mx.gpu(), symbol=softmax, num_epoch=num_epoch,\n",
    "#                              learning_rate=0.05, momentum=0.9, wd=0.00001,\n",
    "#                              lr_scheduler=mx.misc.FactorScheduler(2))\n",
    "# The intent is to multiply the learning rate by 0.1 every 2 epochs.\n",
    "# NOTE(review): confirm FactorScheduler's step/factor semantics for your MXNet version."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we have multiple GPUs, for example 4 GPUs, we can use all of them without any difficulty:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# num_devs = 4\n",
    "# model = mx.model.FeedForward(ctx=[mx.gpu(i) for i in range(num_devs)], symbol=softmax, num_epoch = 1,\n",
    "#                              learning_rate=0.05, momentum=0.9, wd=0.00001)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we declare the data iterators. The original CIFAR-10 data is 3x32x32 in binary format; we provide it in RecordIO format, so we can use the Image RecordIO iterator. For more information about the Image RecordIO iterator, check the [documentation](https://mxnet.readthedocs.org/en/latest/python/io.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Use the utility function from MXNet's tests to download the data,\n",
    "# or prepare it manually\n",
    "import sys\n",
    "sys.path.append(\"../../tests/python/common\") # change the path to mxnet's tests/\n",
    "import get_data\n",
    "get_data.GetCifar10()\n",
    "# After we get the data, we can declare our data iterators\n",
    "# The iterator will automatically create the mean image file if it doesn't exist\n",
    "batch_size = 128\n",
    "# Number of batches per training epoch over the 50000 CIFAR-10 training images:\n",
    "# ceil(50000 / batch_size). Floor division keeps it an int under Python 2 and 3,\n",
    "# and it follows batch_size instead of a hard-coded 128.\n",
    "total_batch = (50000 + batch_size - 1) // batch_size\n",
    "# Train iterator: batches of batch_size images, each randomly cropped to 3x28x28 from the original 3x32x32\n",
    "train_dataiter = mx.io.ImageRecordIter(\n",
    "        shuffle=True,\n",
    "        path_imgrec=\"data/cifar/train.rec\",\n",
    "        mean_img=\"data/cifar/cifar_mean.bin\",\n",
    "        rand_crop=True,\n",
    "        rand_mirror=True,\n",
    "        data_shape=(3,28,28),\n",
    "        batch_size=batch_size,\n",
    "        preprocess_threads=1)\n",
    "# Test iterator: batches of batch_size images, each center-cropped to 3x28x28 from the original 3x32x32\n",
    "# Note: we don't need round_batch for testing because we only run through the test set once at a time\n",
    "test_dataiter = mx.io.ImageRecordIter(\n",
    "        path_imgrec=\"data/cifar/test.rec\",\n",
    "        mean_img=\"data/cifar/cifar_mean.bin\",\n",
    "        rand_crop=False,\n",
    "        rand_mirror=False,\n",
    "        data_shape=(3,28,28),\n",
    "        batch_size=batch_size,\n",
    "        round_batch=False,\n",
    "        preprocess_threads=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "Now we can fit the model with data. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:Start training with [gpu(0)]\n",
      "INFO:root:Iter[0] Batch [50]\tSpeed: 1053.96 samples/sec\n",
      "INFO:root:Iter[0] Batch [100]\tSpeed: 1021.90 samples/sec\n",
      "INFO:root:Iter[0] Batch [150]\tSpeed: 1020.08 samples/sec\n",
      "INFO:root:Iter[0] Batch [200]\tSpeed: 1017.71 samples/sec\n",
      "INFO:root:Iter[0] Batch [250]\tSpeed: 1008.16 samples/sec\n",
      "INFO:root:Iter[0] Batch [300]\tSpeed: 1011.40 samples/sec\n",
      "INFO:root:Iter[0] Batch [350]\tSpeed: 995.93 samples/sec\n",
      "INFO:root:Epoch[0] Train-accuracy=0.719769\n",
      "INFO:root:Epoch[0] Time cost=50.322\n",
      "INFO:root:Epoch[0] Validation-accuracy=0.660008\n"
     ]
    }
   ],
   "source": [
    "# Train the model, evaluating accuracy on the test set and logging speed via the callback\n",
    "model.fit(X=train_dataiter,\n",
    "          eval_data=test_dataiter,\n",
    "          eval_metric=\"accuracy\",\n",
    "          batch_end_callback=mx.callback.Speedometer(batch_size))\n",
    "\n",
    "# To save the model after every epoch, add a checkpoint callback\n",
    "# model_prefix = './cifar_'\n",
    "# model.fit(X=train_dataiter,\n",
    "#           eval_data=test_dataiter,\n",
    "#           eval_metric=\"accuracy\",\n",
    "#           batch_end_callback=mx.callback.Speedometer(batch_size),\n",
    "#           epoch_end_callback=mx.callback.do_checkpoint(model_prefix))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After only 1 epoch, our model is able to achieve about 65% accuracy on the test set (if not, try running it again).\n",
    "We can save our model either by calling ```save``` or by using ```pickle```.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:Saved checkpoint to \"cifar10-0001.params\"\n"
     ]
    }
   ],
   "source": [
    "prefix = \"cifar10\"\n",
    "# Option 1: serialize the whole model object into a byte string with pickle\n",
    "import pickle\n",
    "smodel = pickle.dumps(model)\n",
    "# Option 2 (recommended): model.save, which can also save directly to cloud storage (S3, HDFS)\n",
    "model.save(prefix)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To load a saved model, use ```pickle``` if the model was serialized with ```pickle```, or use ```load``` if it was written by ```save```."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Option 1: restore the pickled copy. Only unpickle trusted data; smodel was\n",
    "# produced by this notebook in the previous cell.\n",
    "model2 = pickle.loads(smodel)\n",
    "# Option 2: read back the files written by model.save(prefix) for epoch num_epoch\n",
    "# (this load method is able to load from S3/HDFS directly)\n",
    "model3 = mx.model.FeedForward.load(prefix, num_epoch, ctx=mx.gpu())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can use the loaded model to make predictions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:Finish predict...\n",
      "INFO:root:final accuracy = 0.659900\n"
     ]
    }
   ],
   "source": [
    "prob = model3.predict(test_dataiter)\n",
    "logging.info('Finish predict...')\n",
    "# Check the accuracy of the prediction: rewind the iterator and read the labels\n",
    "test_dataiter.reset()\n",
    "# The iterator may pad a batch to the full batch shape, so drop the padded\n",
    "# samples from each batch before collecting the labels\n",
    "y_batch = []\n",
    "for dbatch in test_dataiter:\n",
    "    label = dbatch.label[0].asnumpy()\n",
    "    pad = test_dataiter.getpad()\n",
    "    real_size = label.shape[0] - pad\n",
    "    y_batch.append(label[0:real_size])\n",
    "y = np.concatenate(y_batch)\n",
    "\n",
    "# Predicted label = index of the highest class probability\n",
    "py = np.argmax(prob, axis=1)\n",
    "# Fail loudly instead of silently comparing arrays of different lengths\n",
    "assert py.shape[0] == y.shape[0], 'prediction/label count mismatch: %d vs %d' % (py.shape[0], y.shape[0])\n",
    "acc1 = float(np.mean(py == y))\n",
    "logging.info('final accuracy = %f', acc1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "From any symbol, we can access its internal feature maps and bind a new model to extract one of them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(10000, 336, 1, 1)\n"
     ]
    }
   ],
   "source": [
    "# Predict internal feature maps.\n",
    "# get_internals() gives every internal output of softmax; the result is still a symbol.\n",
    "internals = softmax.get_internals()\n",
    "# Internal outputs are named \"<symbol name>_output\". The pooling layer was named\n",
    "# \"global_avg\", so its output is \"global_avg_output\".\n",
    "# Call internals.list_outputs() to see all available names; giving important\n",
    "# layers an explicit name (as done for global_avg) makes them easy to find here.\n",
    "fea_symbol = internals[\"global_avg_output\"]\n",
    "\n",
    "# Build a new model around the internal symbol, reusing the trained parameters.\n",
    "# allow_extra_params=True is required because the trained parameters include the\n",
    "# FullyConnected layer, which fea_symbol does not need.\n",
    "feature_extractor = mx.model.FeedForward(\n",
    "    ctx=mx.gpu(),\n",
    "    symbol=fea_symbol,\n",
    "    arg_params=model.arg_params,\n",
    "    aux_params=model.aux_params,\n",
    "    allow_extra_params=True)\n",
    "# Predict as usual; the result holds the pooled features of the test images\n",
    "global_pooling_feature = feature_extractor.predict(test_dataiter)\n",
    "print(global_pooling_feature.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
